import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

# Bounds for a log standard deviation. Presumably a Gaussian policy clamps
# log_std into [LOG_SIG_MIN, LOG_SIG_MAX]; the consumer is not visible here,
# so TODO confirm against the caller.
LOG_SIG_MAX: int = 2
LOG_SIG_MIN: int = -10
# Small positive constant; presumably guards a log/division against zero
# (e.g. in a tanh-squash log-prob correction) -- verify at the use site.
epsilon: float = 1e-6
# scale = 4
# Weight initializer meant to be passed to nn.Module.apply(): Xavier-uniform weights and zero biases for nn.Linear layers.
def weights_init_(m: nn.Module) -> None:
    """Initialize an ``nn.Linear`` layer in place; other module types are left untouched.

    Intended to be passed to ``nn.Module.apply`` so that it visits every
    submodule of a network. Linear weights get Xavier/Glorot-uniform
    initialization with gain 1, and Linear biases (when present) are zeroed.

    Args:
        m: Any ``nn.Module``. Only ``nn.Linear`` instances are modified.
    """
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight, gain=1)
        # nn.Linear(..., bias=False) stores ``bias = None``; calling constant_
        # on None would raise, so only zero a bias that actually exists.
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)


class ValueNetwork(nn.Module):
    """State-value network: Linear encoder -> LayerNorm/GELU MLP trunk -> scalar head.

    Args:
        num_inputs: Size of the last dimension of ``state``.
        hidden_dim: Width of every hidden layer.
    """

    def __init__(self, num_inputs, hidden_dim):
        super().__init__()

        self.encoder = nn.Linear(num_inputs, hidden_dim)

        # Trunk layout, index for index (so state_dict keys block.0..block.8
        # are unchanged): a leading LayerNorm, then three Linear+LayerNorm
        # stages with a GELU after each of the first two.
        trunk = [nn.LayerNorm(hidden_dim)]
        for stage in range(3):
            trunk.append(nn.Linear(hidden_dim, hidden_dim))
            trunk.append(nn.LayerNorm(hidden_dim))
            if stage < 2:
                trunk.append(nn.GELU())
        self.block = nn.Sequential(*trunk)

        self.decoder = nn.Linear(hidden_dim, 1)

        # Re-initialize every nn.Linear (encoder, trunk, decoder) in place.
        self.apply(weights_init_)

    def forward(self, state):
        """Return the value estimate for ``state`` (trailing dimension of size 1)."""
        return self.decoder(self.block(self.encoder(state)))


class QNetwork(nn.Module):
    """Twin Q-networks over the concatenated (state, action) input.

    Each critic is a Linear encoder -> LayerNorm/GELU MLP trunk -> scalar head.
    The two critics have identical layouts but separate parameters.

    Args:
        num_inputs: Size of the last dimension of ``state``.
        num_actions: Size of the last dimension of ``action``.
        hidden_dim: Width of every hidden layer.
    """

    def __init__(self, num_inputs, num_actions, hidden_dim):
        super().__init__()
        in_features = num_inputs + num_actions

        # Submodules are created in the same order as before (all of Q1, then
        # all of Q2) so parameter registration order and state_dict keys match.
        self.Q1_encoder = nn.Linear(in_features, hidden_dim)
        self.Q1_block = self._make_trunk(hidden_dim)
        self.Q1_decoder = nn.Linear(hidden_dim, 1)

        self.Q2_encoder = nn.Linear(in_features, hidden_dim)
        self.Q2_block = self._make_trunk(hidden_dim)
        self.Q2_decoder = nn.Linear(hidden_dim, 1)

        # Re-initialize every nn.Linear in both critics in place.
        self.apply(weights_init_)

    @staticmethod
    def _make_trunk(hidden_dim):
        """Build one critic's 9-layer LayerNorm/Linear/GELU hidden trunk."""
        return nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
        )

    def forward(self, state, action):
        """Return ``(q1, q2)``, each with a trailing dimension of size 1."""
        state_action = torch.cat([state, action], dim=-1)
        q1 = self.Q1_decoder(self.Q1_block(self.Q1_encoder(state_action)))
        q2 = self.Q2_decoder(self.Q2_block(self.Q2_encoder(state_action)))
        return q1, q2
